import torch
import torch.nn as nn

# Flattened input size: one 1x28x28 image becomes a vector of D features.
D = 28 * 28
# Width (number of features) of every hidden layer.
n = 256
# Number of output logits produced by the final layer.
classes = 10
# How many extra (n -> n) Linear+Tanh hidden blocks follow the input projection.
n_hidden = 1

# Fully connected classifier:
#   Flatten -> Linear(D, n) -> Tanh -> n_hidden x [Linear(n, n) -> Tanh] -> Linear(n, classes)
fc_model = nn.Sequential(
    nn.Flatten(),
    # The input projection must emit n features, because every following
    # layer (hidden blocks and the output head) expects n input features.
    nn.Linear(D, n),
    nn.Tanh(),
    *[nn.Sequential(nn.Linear(n, n), nn.Tanh()) for _ in range(n_hidden)],
    nn.Linear(n, classes),
)
from torchsummary import summary
# torchsummary.summary defaults to device="cuda" and then builds its probe
# input on the GPU whenever CUDA is available. fc_model is not moved to a GPU
# in this script, so pass the device its parameters actually live on to avoid
# a CPU-weights / CUDA-input mismatch on CUDA-capable machines.
summary(
    fc_model,
    input_size=(1, 28, 28),
    device="cuda" if next(fc_model.parameters()).is_cuda else "cpu",
)
